{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the NVIDIA driver's GPU status table before any model is loaded.\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import time\n",
    "import torch\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "\n",
    "from safetensors.torch import save_file\n",
    "from src.pipeline import FluxPipeline\n",
    "from src.transformer_flux import FluxTransformer2DModel\n",
    "from src.lora_helper import set_single_lora, set_multi_lora, unset_lora\n",
    "\n",
    "# Prefer the second GPU (index 1) as before, but fall back to index 0 when fewer\n",
    "# than two CUDA devices are visible, where set_device(1) would raise.\n",
    "torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)\n",
    "\n",
    "class ImageProcessor:\n",
    "    \"\"\"Loads a FLUX pipeline once and generates images with optional condition images.\"\"\"\n",
    "\n",
    "    def __init__(self, path):\n",
    "        \"\"\"Load the pipeline and its transformer from ``path`` in bfloat16 and move the pipeline to CUDA.\n",
    "\n",
    "        Args:\n",
    "            path: Location passed to both ``from_pretrained`` calls; the transformer\n",
    "                is read from the ``transformer`` subfolder.\n",
    "        \"\"\"\n",
    "        device = \"cuda\"\n",
    "        self.pipe = FluxPipeline.from_pretrained(path, torch_dtype=torch.bfloat16, device=device)\n",
    "        transformer = FluxTransformer2DModel.from_pretrained(path, subfolder=\"transformer\",torch_dtype=torch.bfloat16, device=device)\n",
    "        # Replace the pipeline's transformer with the project's FluxTransformer2DModel.\n",
    "        self.pipe.transformer = transformer\n",
    "        self.pipe.to(device)\n",
    "\n",
    "    def clear_cache(self, transformer):\n",
    "        \"\"\"Call ``bank_kv.clear()`` on every attention processor of ``transformer``.\"\"\"\n",
    "        for attn_processor in transformer.attn_processors.values():\n",
    "            attn_processor.bank_kv.clear()\n",
    "\n",
    "    @staticmethod\n",
    "    def _load_rgb_images(image_paths):\n",
    "        \"\"\"Open every path in ``image_paths`` and return the images converted to RGB.\n",
    "\n",
    "        ``None`` or an empty sequence yields an empty list.\n",
    "        \"\"\"\n",
    "        images = []\n",
    "        for image_path in image_paths or ():\n",
    "            # convert() returns a new, fully loaded image, so the file handle\n",
    "            # opened by Image.open can be released as soon as it returns.\n",
    "            with Image.open(image_path) as source:\n",
    "                images.append(source.convert(\"RGB\"))\n",
    "        return images\n",
    "\n",
    "    def process_image(self, prompt='', subject_imgs=[], spatial_imgs=[], height = 768, width = 768, output_path=None, seed=42):\n",
    "        \"\"\"Generate one image from ``prompt`` and the given condition images.\n",
    "\n",
    "        The pipeline is called with fixed settings: guidance_scale=3.5,\n",
    "        num_inference_steps=25, max_sequence_length=512 and cond_size=512.\n",
    "\n",
    "        Args:\n",
    "            prompt: Text prompt passed to the pipeline.\n",
    "            subject_imgs: Paths of images passed as ``subject_images`` (may be empty or None).\n",
    "            spatial_imgs: Paths of images passed as ``spatial_images`` (may be empty or None).\n",
    "            height: Output height, cast to int.\n",
    "            width: Output width, cast to int.\n",
    "            output_path: If truthy, the generated image is also saved to this path.\n",
    "            seed: Seed for the CPU ``torch.Generator`` handed to the pipeline.\n",
    "\n",
    "        Returns:\n",
    "            The first image of the pipeline output, so a notebook cell ending in this\n",
    "            call can render it inline.\n",
    "        \"\"\"\n",
    "        # The list defaults are kept for call compatibility; they are never mutated here.\n",
    "        spatial_ls = self._load_rgb_images(spatial_imgs)\n",
    "        subject_ls = self._load_rgb_images(subject_imgs)\n",
    "\n",
    "        try:\n",
    "            image = self.pipe(\n",
    "                prompt,\n",
    "                height=int(height),\n",
    "                width=int(width),\n",
    "                guidance_scale=3.5,\n",
    "                num_inference_steps=25,\n",
    "                max_sequence_length=512,\n",
    "                generator=torch.Generator(\"cpu\").manual_seed(seed),\n",
    "                subject_images=subject_ls,\n",
    "                spatial_images=spatial_ls,\n",
    "                cond_size=512,\n",
    "            ).images[0]\n",
    "        finally:\n",
    "            # Clear bank_kv even when generation raises, so whatever a failed run\n",
    "            # left in it is not carried into the next call.\n",
    "            self.clear_cache(self.pipe.transformer)\n",
    "        image.show()\n",
    "        if output_path:\n",
    "            image.save(output_path)\n",
    "        return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### models path ###\n",
    "base_path = \"FLUX.1-dev\"  # your flux model path\n",
    "lora_path = \"./models\"  # your lora folder path\n",
    "\n",
    "\n",
    "def _lora_file(suffix):\n",
    "    \"\"\"Return ``lora_path`` with ``suffix`` appended.\"\"\"\n",
    "    return lora_path + suffix\n",
    "\n",
    "\n",
    "# spatial control models\n",
    "canny_path = _lora_file(\"/canny.safetensors\")\n",
    "depth_path = _lora_file(\"/depth.safetensors\")\n",
    "openpose_path = _lora_file(\"/pose.safetensors\")\n",
    "inpainting_path = _lora_file(\"/inpainting.safetensors\")\n",
    "hedsketch_path = _lora_file(\"/hedsketch.safetensors\")\n",
    "seg_path = _lora_file(\"/seg.safetensors\")\n",
    "\n",
    "# subject control model\n",
    "subject_path = _lora_file(\"/subject.safetensors\")\n",
    "\n",
    "# build the image processor from the base model path\n",
    "processor = ImageProcessor(base_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference with a single condition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach one control LoRA to the pipeline's transformer.\n",
    "path = depth_path  # single control model path\n",
    "lora_weights = [1]  # lora weights for each control model\n",
    "set_single_lora(\n",
    "    processor.pipe.transformer,\n",
    "    path,\n",
    "    lora_weights=lora_weights,\n",
    "    cond_size=512,\n",
    ")\n",
    "\n",
    "# Generate from the prompt plus one spatial condition image.\n",
    "prompt = 'a cafe bar'\n",
    "spatial_imgs = [\"./test_imgs/depth.png\"]\n",
    "height = width = 1024\n",
    "processor.process_image(\n",
    "    prompt=prompt,\n",
    "    subject_imgs=[],\n",
    "    spatial_imgs=spatial_imgs,\n",
    "    height=height,\n",
    "    width=width,\n",
    "    seed=11,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference with multiple conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the subject and inpainting control LoRAs to the pipeline's transformer.\n",
    "paths = [subject_path, inpainting_path]  # multi control model paths\n",
    "lora_weights = [[1], [1]]  # lora weights for each control model\n",
    "set_multi_lora(processor.pipe.transformer, paths, lora_weights=lora_weights, cond_size=512)\n",
    "\n",
    "# infer\n",
    "prompt = 'A SKS on the car'\n",
    "# The subject reference goes in subject_imgs and the inpainting image in\n",
    "# spatial_imgs; the two lists were previously assigned the other way round.\n",
    "subject_imgs = [\"./test_imgs/subject_1.png\"]\n",
    "spatial_imgs = [\"./test_imgs/inpainting.png\"]\n",
    "height = 1024\n",
    "width = 1024\n",
    "processor.process_image(\n",
    "    prompt=prompt,\n",
    "    subject_imgs=subject_imgs,\n",
    "    spatial_imgs=spatial_imgs,\n",
    "    height=height,\n",
    "    width=width,\n",
    "    seed=42,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "zyxdit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
